# @Time   : 2022/3/12
# @Author : Zihan Lin
# @Email  : zhlin@ruc.edu.cn
"""
recbole_cdr.model.crossdomain_recommender
##################################
"""
from recbole.model.abstract_recommender import AbstractRecommender
import numpy as np
from recbole_cdr.utils import ModelType
import torch
from scipy.sparse import coo_matrix
import scipy.sparse as sp
from generative_model.deleter import Deleter,train2
#from generative_model.deleter_shapley import Deleter_shapley,train2


from generative_model.generator import Generator,train1


class CrossDomainRecommender(AbstractRecommender):
    """Abstract cross-domain recommender. All cross-domain models should subclass this class.

    The base class collects the basic dataset information (user/item id fields
    and counts for the source domain, the target domain, their union and their
    overlap) and the raw source/target interaction id lists. Depending on the
    config it can first regenerate the source interactions with a
    ``Generator`` (``config['generate']``) and/or prune interactions with a
    ``Deleter`` (``config['delete']``). It then builds the float32 sparse
    interaction matrices of both domains.
    """
    type = ModelType.CROSSDOMAIN

    def __init__(self, config, dataset):
        super(CrossDomainRecommender, self).__init__()
        self.config = config

        # load source dataset info
        source_data = dataset.source_domain_dataset
        self.SOURCE_USER_ID = source_data.uid_field
        self.SOURCE_ITEM_ID = source_data.iid_field
        self.SOURCE_LABEL = 'source_label'
        self.TARGET_LABEL = 'target_label'
        self.SOURCE_NEG_ITEM_ID = config['source_domain']['NEG_PREFIX'] + self.SOURCE_ITEM_ID
        self.source_num_users = source_data.num(self.SOURCE_USER_ID)
        self.source_num_items = source_data.num(self.SOURCE_ITEM_ID)

        # load target dataset info
        target_data = dataset.target_domain_dataset
        self.TARGET_USER_ID = target_data.uid_field
        self.TARGET_ITEM_ID = target_data.iid_field
        self.TARGET_NEG_ITEM_ID = config['target_domain']['NEG_PREFIX'] + self.TARGET_ITEM_ID
        self.target_num_users = target_data.num(self.TARGET_USER_ID)
        self.target_num_items = target_data.num(self.TARGET_ITEM_ID)

        # load both dataset info
        self.total_num_users = dataset.num_total_user
        self.total_num_items = dataset.num_total_item
        self.overlapped_num_users = dataset.num_overlap_user
        self.overlapped_num_items = dataset.num_overlap_item
        self.OVERLAP_ID = dataset.overlap_id_field

        # load parameters info
        self.device = config['device']

        # Parallel (user ids, item ids) interaction lists of each domain; the
        # optional generator/deleter phases below may replace them.
        self.source_u, self.source_i = dataset.interactions(domain='source')
        self.target_u, self.target_i = dataset.interactions(domain='target')

        if config['generate'] == True:  # noqa: E712 -- strict check kept on purpose
            self._regenerate_source_interactions(config)

        if config['delete'] == True:  # noqa: E712 -- strict check kept on purpose
            self._prune_interactions(config, dataset)

        self.source_interaction_matrix = self.inter_matrix(domain='source').astype(np.float32)
        self.target_interaction_matrix = self.inter_matrix(domain='target').astype(np.float32)

    def _domain_args(self):
        """Return the positional arguments shared by ``Generator`` and ``Deleter``.

        The values are read at call time, so interactions rewritten by an
        earlier phase are the ones passed on.
        """
        return (
            self.source_u,
            self.source_i,
            self.target_u,
            self.target_i,
            self.total_num_users,
            self.total_num_items,
            self.overlapped_num_users,
            self.source_num_users,
            self.source_num_items,
            self.target_num_users,
            self.target_num_items,
        )

    def _regenerate_source_interactions(self, config):
        """Train a ``Generator`` and replace the source interactions with its output."""
        generator = Generator(config, *self._domain_args()).to(config['device'])
        train1(config, generator)
        self.source_u, self.source_i = generator.data_reproduce()
        # Release the generator (and its device memory) before continuing.
        del generator

    def _prune_interactions(self, config, dataset):
        """Train a ``Deleter`` and replace the interactions selected by ``delete_domain``.

        ``config['delete_domain'] == 'source'`` replaces only the source
        interactions; ``'both'`` replaces source and target interactions.
        """
        deleter = Deleter(config, *self._domain_args(), dataset).to(config['device'])
        train2(config, deleter)
        delete_domain = config['delete_domain']
        if delete_domain == 'source':
            self.source_u, self.source_i = deleter.data_reproduce()
        elif delete_domain == 'both':
            (self.source_u, self.source_i,
             self.target_u, self.target_i) = deleter.data_reproduce()
        # NOTE(review): any other ``delete_domain`` value trains the deleter but
        # leaves the interactions untouched -- confirm this is intended.
        del deleter

    def inter_matrix(self, domain='source'):
        """Return the interaction matrix of ``domain`` as a ``scipy.sparse.coo_matrix``.

        Args:
            domain (str): ``'source'`` selects the source interactions; any other
                value selects the target interactions.

        Returns:
            coo_matrix: shape ``(total_num_users, total_num_items)``, with a 1 at
            every observed (user, item) pair.
        """
        if domain == 'source':
            users, items = self.source_u, self.source_i
        else:
            users, items = self.target_u, self.target_i
        return self.get_sparse_matrix(self.total_num_users, self.total_num_items, users.cpu(), items.cpu())

    def get_sparse_matrix(self, user_num, item_num, src, tgt):
        """Build a ``user_num x item_num`` COO matrix with a 1.0 at each ``(src[k], tgt[k])``.

        Duplicate pairs are stored as separate entries; scipy sums them when the
        matrix is converted to another sparse format.
        """
        data = np.ones(len(src))
        return coo_matrix((data, (src, tgt)), shape=(user_num, item_num))

    def generate_negative_samples(self, total_num_items, items, domain):
        """Draw one uniformly random negative item id per element of ``items``.

        Args:
            total_num_items (int): exclusive upper bound for source-domain ids.
            items (torch.Tensor): batch items; only its length and device are used.
            domain (str): ``'source'`` samples from
                ``[target_num_items, total_num_items)``; any other value samples
                from ``[1, target_num_items)``.

        Returns:
            torch.Tensor: 1-D int64 tensor of length ``len(items)`` on
            ``items.device``. A batch of one yields shape ``(1,)`` rather than a
            0-d scalar.
        """
        if domain == 'source':
            low, high = self.target_num_items, total_num_items
        else:
            low, high = 1, self.target_num_items
        # Sample on the CPU generator, then move, as before (keeps the RNG stream).
        return torch.randint(low, high, (len(items),), dtype=torch.int64).to(items.device)

    def mask_delete(self, source_user, source_item, source_label=None):
        """Keep only the batch interactions that are present in the stored source set.

        When ``config['mask_delete']`` is enabled, a (user, item) pair of the
        batch survives only if the same pair appears in
        ``self.source_u``/``self.source_i``, which may have been pruned by the
        deleter. When the option is disabled, the inputs are returned unchanged.

        Args:
            source_user (torch.Tensor): batch user ids.
            source_item (torch.Tensor): batch item ids, aligned with ``source_user``.
            source_label (torch.Tensor, optional): labels aligned with ``source_user``.

        Returns:
            tuple: ``(source_user, source_item, source_label)`` filtered by the same
            mask. ``source_label`` stays ``None`` if it was not given.
        """
        if self.config['mask_delete'] != True:  # noqa: E712 -- strict check kept on purpose
            return source_user, source_item, source_label

        kept_u = self.source_u.to(self.device).long()
        kept_i = self.source_i.to(self.device).long()
        batch_u = source_user.to(self.device).long()
        batch_i = source_item.to(self.device).long()

        # Encode each (user, item) pair as one integer key. The radix must exceed
        # every item id on BOTH sides, otherwise distinct pairs collide
        # (e.g. (0, radix) and (1, 0)) and would be kept by mistake.
        radix = 1
        if kept_i.numel() > 0:
            radix = max(radix, int(kept_i.max().item()) + 1)
        if batch_i.numel() > 0:
            radix = max(radix, int(batch_i.max().item()) + 1)
        kept_hash = kept_u * radix + kept_i
        batch_hash = batch_u * radix + batch_i

        # Index the batch tensors with a mask living on their own device.
        mask = torch.isin(batch_hash, kept_hash).to(source_user.device)
        source_user = source_user[mask]
        source_item = source_item[mask]
        if source_label is not None:
            source_label = source_label[mask]
        return source_user, source_item, source_label

    def set_phase(self, phase):
        """Hook for switching the model's training phase; the base class does nothing."""
        pass
